# Copyright 2022 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import json
import math
import os
import signal
import subprocess
import sys
import tempfile
import time
from copy import deepcopy
from functools import partial
from typing import Callable

import dill
from datatrove.executor.base import PipelineExecutor
from datatrove.io import DataFolderLike
from datatrove.pipeline.base import PipelineStep
from datatrove.utils.logging import get_random_str, get_timestamp, logger
from dill import CONTENTS_FMODE


def requeue_handler(signum, _frame):
    """Signal handler: ask slurm to requeue the current job, then exit.

    Args:
        signum: number of the signal that was received
        _frame: current stack frame (unused, part of the handler signature)
    """
    signal_name = signal.Signals(signum).name
    logger.warning(
        f"Received signal {signum} ({signal_name}). Requeueing and exiting..."
    )
    # the id of the job to requeue is read from the SLURM_JOB_ID env variable
    current_job_id = os.environ.get("SLURM_JOB_ID")
    requeue_cmd = ["scontrol", "requeue", current_job_id]
    subprocess.run(requeue_cmd)
    sys.exit(15)


class SlurmPipelineExecutor(PipelineExecutor):
    """Execute a pipeline on a slurm cluster
    Creates and calls a sbatch launch script. The pipeline itself is run
    through `singularity exec` inside the container given by `container_path`.

    [!] do not launch tasks from within a compute node/from another slurm task!

    Args:
        pipeline: a list of PipelineStep and/or custom functions
            with arguments (data: DocumentsPipeline, rank: int,
            world_size: int)
        tasks: total number of tasks to run the pipeline on
        time: slurm time limit
        partition: slurm partition
        cpus_per_task: how many cpus to give each task. should be 1
            except when you need to give each task more memory
        mem_per_cpu_gb: slurm option. use in conjunction with the
            above option to increase max memory
        workers: how many tasks to run simultaneously. -1 for no
            limit
        job_name: slurm job name
        qos: slurm qos, passed as `--qos`. A falsy value omits the directive
        env_command: command to activate a python environment, if
            needed
        condaenv: name of a conda environment to activate
        venv_path: path to a python venv to activate
        container_path: path to the singularity container to use.
            Required: a ValueError is raised if it is None
        sbatch_args: dictionary with additional arguments to pass to
            sbatch
        max_array_size: the limit of tasks in a task array job on
            your slurm cluster or -1 if none. if
            tasks>max_array_size, multiple task array jobs will be
            launched
        depends: another SlurmPipelineExecutor that should run
            before this one
        depends_job_id: alternatively to the above, you can pass the job id of a dependency
        job_id_position: position of job ID in custom Sbatch outputs.
            default: -1
        job_id_retriever: a callable that takes the output of the sbatch command (as written to terminal)
            as input and returns the extracted job id to be used as 'self.job_id'. Defaults to take the
            `job_id_position`-th element from the split output. default: default_job_id_retriever
        logging_dir: where to save logs, stats, etc. Should be parsable into a datatrove.io.DataFolder
        skip_completed: whether to skip tasks that were completed in
            previous runs. default: True
        slurm_logs_folder: where to store the raw slurm log files.
            must be a local path default:
            slurm_logs/$job_name/$timestamp_$randomstring
        max_array_launch_parallel: if we need multiple jobs due to max_array_size, whether to launch them all in
            one go (parallel) or sequentially
        stagger_max_array_jobs: when max_array_launch_parallel is True, this determines how many seconds to wait
            between launching each of the parallel jobs
        run_on_dependency_fail: start executing when a job we depend on finishes even if it has failed
        randomize_start_duration: the maximum number of seconds to delay the start of each task.
        requeue_signals: requeue the job and exit when one of these signals is received. Useful for when an instance
            is being reclaimed and jobs must be stopped for example. Set to None to disable
        mail_type: see https://slurm.schedmd.com/sbatch.html. Common values are (NONE, BEGIN, END, FAIL, REQUEUE, ALL)
        mail_user: email address to send notifications to
        requeue: requeue the job if it fails
        srun_args: dictionary with additional options for the `srun` command, each rendered as `--key=value`
        tasks_per_job: each slurm job in the job array will run these many datatrove tasks. This reduces the total nb of slurm jobs launched.
        use_gpu: if True, `--nv` is added to the `singularity exec` commands
    """

    def __init__(
        self,
        pipeline: list[PipelineStep | Callable],
        tasks: int,
        time: str,
        partition: str | None = None,
        cpus_per_task: int = 1,
        mem_per_cpu_gb: int = 2,
        workers: int = -1,
        job_name: str = "data_processing",
        qos: str = "normal",
        env_command: str | None = None,
        condaenv: str | None = None,
        venv_path: str | None = None,
        container_path: str | None = None,
        sbatch_args: dict | None = None,
        max_array_size: int = 1001,
        depends: SlurmPipelineExecutor | None = None,
        depends_job_id: str | None = None,
        job_id_position: int = -1,
        job_id_retriever: Callable | None = None,
        logging_dir: DataFolderLike = None,
        skip_completed: bool = True,
        slurm_logs_folder: str | None = None,
        max_array_launch_parallel: bool = False,
        stagger_max_array_jobs: int = 0,
        run_on_dependency_fail: bool = False,
        randomize_start_duration: int = 0,
        requeue_signals: tuple[str, ...] | None = ("SIGUSR1",),
        mail_type: str = "ALL",
        mail_user: str | None = None,
        requeue: bool = True,
        srun_args: dict | None = None,
        tasks_per_job: int = 1,
        use_gpu: bool = False,
    ):
        super().__init__(
            pipeline, logging_dir, skip_completed, randomize_start_duration
        )
        self.tasks = tasks
        self.workers = workers
        self.partition = partition
        self.cpus_per_task = cpus_per_task
        self.mem_per_cpu_gb = mem_per_cpu_gb
        self.tasks_per_job = tasks_per_job
        self.time = time
        self.job_name = job_name
        self.qos = qos
        self.env_command = env_command
        self.condaenv = condaenv
        self.venv_path = venv_path
        self.container_path = container_path
        # every generated command goes through `singularity exec <container>`,
        # so there is nothing sensible to launch without a container
        if self.container_path is None:
            raise ValueError(
                "Container path must be set. Please set the container_path argument."
            )
        self.depends = depends
        self.depends_job_id = depends_job_id
        self.job_id_position = job_id_position

        if job_id_retriever is None:
            job_id_retriever = partial(
                default_job_id_retriever, job_id_position=job_id_position
            )
        self.job_id_retriever = job_id_retriever

        self._sbatch_args = sbatch_args if sbatch_args else {}
        self.max_array_size = max_array_size
        self.max_array_launch_parallel = max_array_launch_parallel
        self.stagger_max_array_jobs = stagger_max_array_jobs
        self.run_on_dependency_fail = run_on_dependency_fail
        self.randomize_start_duration = randomize_start_duration
        # None until a job was launched; -1 when there was nothing left to launch
        self.job_id = None
        self.requeue_signals = requeue_signals
        self.mail_type = mail_type
        self.mail_user = mail_user
        self.srun_args = srun_args
        self.use_gpu = use_gpu
        # explicit folder > "slurm_logs" inside a local logging_dir > relative default folder
        self.slurm_logs_folder = (
            slurm_logs_folder
            if slurm_logs_folder
            else (
                f"slurm_logs/{self.job_name}/{get_timestamp()}_{get_random_str()}"
                if not self.logging_dir.is_local()
                else self.logging_dir.resolve_paths("slurm_logs")
            )
        )
        self.requeue = requeue

    def run(self):
        """
            This method is responsible for correctly invoking `self._run_for_rank` for each task that is to be run.

            On a SlurmPipelineExecutor, we first check if we are already running on a slurm task and, if not, we launch it.
        Returns: None

        """
        if "SLURM_ARRAY_TASK_ID" in os.environ:
            # we are already "inside" the slurm task, get our rank from env vars and run pipeline
            # RUN_OFFSET is the index of the array job this task belongs to (exported by `launch_job`)
            slurm_rank = int(
                os.environ["SLURM_ARRAY_TASK_ID"]
            ) + self.max_array_size * int(os.environ.get("RUN_OFFSET", 0))
            # half-open range of indices into the saved list of ranks handled by this slurm task
            ranks_to_run_range = (
                slurm_rank * self.tasks_per_job,
                (slurm_rank + 1) * self.tasks_per_job,
            )
            with self.logging_dir.open(
                "ranks_to_run.json", "r"
            ) as ranks_to_run_file:
                all_ranks = json.load(ranks_to_run_file)
            # the last array job may contain more slurm tasks than there is work left
            if ranks_to_run_range[0] >= len(all_ranks):
                return

            for ss in self.requeue_signals or []:
                signal.signal(signal.Signals[ss], requeue_handler)

            for rank_to_run in range(*ranks_to_run_range):
                if rank_to_run >= len(all_ranks):
                    break
                rank = all_ranks[rank_to_run]

                self._run_for_rank(rank)
        else:
            # we still have to launch the job
            self.launch_job()

    def launch_merge_stats(self):
        """
            Launch a slurm task to merge the stats of each individual task into one big stats summary file.
            The task runs `merge_stats` inside the container and only starts
            after job `self.job_id` has completed successfully (`afterok`).
        Returns: None

        """
        nv_arg = "--nv " if self.use_gpu else ""
        merge_cmd = (
            f"singularity exec {nv_arg}{self.container_path} merge_stats "
            f"{self.logging_dir.resolve_paths('stats')} "
            f"-o {self.logging_dir.resolve_paths('stats.json')}"
        )
        launch_slurm_job(
            self.get_launch_file_contents(
                {
                    **self.get_sbatch_args(),
                    # merging stats is lightweight: override the pipeline's resource requests
                    "cpus-per-task": 1,
                    "mem-per-cpu": "1G",
                    "dependency": f"afterok:{self.job_id}",
                },
                merge_cmd,
            ),
            self.job_id_retriever,
        )

    @property
    def dependency(self) -> str:
        """
            Resolve list of jobs we depend on and return it as a slurm string.
        Returns: comma separated slurm dependency specification, or an empty string if there are no dependencies

        """
        dependency = []
        if self.depends_job_id:
            dependency.append(
                f"{'afterany' if self.run_on_dependency_fail else 'afterok'}:{self.depends_job_id}"
            )
        # when launching sequentially, each array job waits for the previously launched one
        if self.job_id and not self.max_array_launch_parallel:
            dependency.append(f"afterany:{self.job_id}")
        return ",".join(dependency)

    def launch_job(self):
        """
            Takes care of creating a sbatch script for this pipeline and launching it.
            Unlaunched dependencies are launched first. Sets `self.job_id` to the id of the
            (last) launched job, or to -1 if all tasks had already been completed.
        Returns: None

        """
        assert not self.depends or (
            isinstance(self.depends, SlurmPipelineExecutor)
        ), "depends= must be a SlurmPipelineExecutor"
        if self.depends:
            # take care of launching any unlaunched dependencies and getting their slurm job ids
            if not self.depends.job_id:
                logger.info(
                    f'Launching dependency job "{self.depends.job_name}"'
                )
                self.depends.launch_job()
            # -1 means the dependency had nothing to run, so there is no job to wait for
            if self.depends.job_id != -1:
                self.depends_job_id = self.depends.job_id
            self.depends = None  # avoid pickling the entire dependency and possibly its dependencies

        ranks_to_run = self.get_incomplete_ranks()
        if len(ranks_to_run) == 0:
            logger.info(
                f"Skipping launch of {self.job_name} as all {self.tasks} tasks have already been completed."
            )
            self.job_id = -1
            return

        executor = deepcopy(self)

        # pickle. The slurm job will load the executor from this pik file
        with self.logging_dir.open("executor.pik", "wb") as executor_f:
            dill.dump(executor, executor_f, fmode=CONTENTS_FMODE)
        self.save_executor_as_json()

        with self.logging_dir.open(
            "ranks_to_run.json", "w"
        ) as ranks_to_run_file:
            # we actually save this (only once) to avoid race conditions
            json.dump(ranks_to_run, ranks_to_run_file)

        nb_jobs_to_launch = math.ceil(len(ranks_to_run) / self.tasks_per_job)
        max_array = (
            min(nb_jobs_to_launch, self.max_array_size)
            if self.max_array_size != -1
            else nb_jobs_to_launch
        )

        # create the actual sbatch script
        # changed this to call singularity image inside slurm nodes
        srun_args_str = (
            " ".join([f"--{k}={v}" for k, v in self.srun_args.items()])
            if self.srun_args
            else ""
        )
        nv_arg = "--nv " if self.use_gpu else ""
        srun_cmd = (
            f"srun {srun_args_str} -l -n 1 singularity exec {nv_arg}"
            f"{self.container_path} launch_pickled_pipeline "
            f"{self.logging_dir.resolve_paths('executor.pik')}"
        )
        # we need to add these paths because usually they are inherited from our dev env
        run_script = "\n".join(
            [
                "export PYTHONPATH=/modelzoo/data_preparation/:$PYTHONPATH",
                "export PATH=/opt/slurm/bin:$PATH",
                srun_cmd,
            ]
        )
        launch_file_contents = self.get_launch_file_contents(
            self.get_sbatch_args(max_array), run_script
        )

        # save it
        with self.logging_dir.open(
            "launch_script.slurm", "w"
        ) as launchscript_f:
            launchscript_f.write(launch_file_contents)
        logger.info(
            f"Launching Slurm job {self.job_name} ({len(ranks_to_run)} tasks) with launch script "
            f'"{self.logging_dir.resolve_paths("launch_script.slurm")}"'
        )

        # launch (possibly multiple) jobs
        launched_jobs = 0
        while launched_jobs * max_array < nb_jobs_to_launch:
            # optionally space out the submissions of parallel array jobs
            if (
                launched_jobs
                and self.max_array_launch_parallel
                and self.stagger_max_array_jobs > 0
            ):
                time.sleep(self.stagger_max_array_jobs)
            # RUN_OFFSET lets `run` compute the global rank from the array task id
            args = [f"--export=ALL,RUN_OFFSET={launched_jobs}"]
            if self.dependency:
                args.append(f"--dependency={self.dependency}")
            self.job_id = launch_slurm_job(
                launch_file_contents, self.job_id_retriever, *args
            )
            launched_jobs += 1
        logger.info(
            f"Slurm job launched successfully with (last) id={self.job_id}."
        )
        self.launch_merge_stats()

    def get_sbatch_args(self, max_array: int = 1) -> dict:
        """
            Get a dictionary with all the sbatch directives we want to include.
            Also makes sure that the folder slurm will write its log files to exists.
        Args:
            max_array: max array size

        Returns: a dictionary with all the sbatch directives

        """
        # this one we actually have to create as slurm will be writing here
        os.makedirs(self.slurm_logs_folder, exist_ok=True)
        slurm_logfile = os.path.abspath(
            os.path.join(self.slurm_logs_folder, "%A_%a.out")
        )

        sbatch_args = {
            "cpus-per-task": self.cpus_per_task,
            "mem-per-cpu": f"{self.mem_per_cpu_gb}G",
            **({"partition": self.partition} if self.partition else {}),
            "job-name": self.job_name,
            "time": self.time,
            "output": slurm_logfile,
            "error": slurm_logfile,
            # "%<workers>" limits how many array tasks run simultaneously
            "array": f"0-{max_array - 1}{f'%{self.workers}' if self.workers != -1 else ''}",
            **(
                {"mail-type": self.mail_type, "mail-user": self.mail_user}
                if self.mail_user
                else {}
            ),
            # user supplied arguments override the defaults above
            **self._sbatch_args,
        }
        if self.requeue:
            # an empty value is rendered as a bare `--requeue` flag
            sbatch_args["requeue"] = ""
        if self.qos:
            sbatch_args["qos"] = self.qos
        return sbatch_args

    def get_launch_file_contents(
        self, sbatch_args: dict, run_script: str
    ) -> str:
        """
        Actually generate the sbatch script
        Args:
            sbatch_args: dictionary with all the sbatch directives to include.
                Entries with a falsy value are written as bare flags (`--key`)
            run_script: command to be invoked by this slurm job

        Returns:
            Full launch script content
        """
        args = "\n".join(
            [
                f"#SBATCH --{k}={v}" if v else f"#SBATCH --{k}"
                for k, v in sbatch_args.items()
            ]
        )

        # environment activation, by priority: env_command > condaenv > venv_path
        if self.env_command:
            env_command = self.env_command
        elif self.condaenv:
            env_command = "\n".join(
                [
                    "conda init bash",
                    f"conda activate {self.condaenv}",
                    "source ~/.bashrc",
                ]
            )
        elif self.venv_path:
            env_command = f"source {self.venv_path}"
        else:
            env_command = ""

        lines = [
            "#!/bin/bash",
            args,
            f'echo "Starting data processing job {self.job_name}"',
            env_command,
            # from here on: echo every command and abort on the first failure
            "set -xe",
            "export PYTHONUNBUFFERED=TRUE",
            run_script,
        ]

        return "\n".join(lines) + "\n"

    @property
    def world_size(self) -> int:
        """Total number of tasks of this pipeline (`self.tasks`)."""
        return self.tasks


def launch_slurm_job(
    launch_file_contents, job_id_retriever: Callable, *args
) -> str:
    """
        Small helper function that writes a sbatch script to a temporary file and submits it.
    Args:
        launch_file_contents: Contents of the sbatch script
        job_id_retriever: callable that receives the decoded output of the sbatch command
            and returns the job id
        *args: any other arguments to pass to the sbatch command

    Returns: the id of the launched slurm job, as returned by `job_id_retriever`

    """
    with tempfile.NamedTemporaryFile("w") as script_file:
        script_file.write(launch_file_contents)
        # make sure the whole script is on disk before sbatch reads it by name
        script_file.flush()
        sbatch_cmd = ["sbatch", *args, script_file.name]
        raw_output = subprocess.check_output(sbatch_cmd)
        sbatch_output = raw_output.decode("utf-8")
        return job_id_retriever(sbatch_output)


def default_job_id_retriever(process_output: str, job_id_position: int):
    """Return the whitespace-separated token of `process_output` found at index `job_id_position`.

    Args:
        process_output: output of the sbatch command
        job_id_position: index (may be negative) of the job id among the tokens
    """
    tokens = process_output.split()
    return tokens[job_id_position]
